from typing import Dict, Optional

from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, TAGGING_OPS, Mapper

OP_NAME = "query_topic_detection_mapper"


@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class QueryTopicDetectionMapper(Mapper):
    """Predicts the topic label and its corresponding score for a given query. The input is
    taken from the specified query key. The output, which includes the predicted topic label
    and its score, is stored in the 'query_topic_label' and 'query_topic_label_score' fields
    of the Data-Juicer meta field. This operator uses a Hugging Face model for topic
    classification. If a Chinese to English translation model is provided, it will first
    translate the query from Chinese to English before predicting the topic.

    - Uses a Hugging Face model for topic classification.
    - Optionally translates Chinese queries to English using another Hugging Face
    model.
    - Stores the predicted topic label in 'query_topic_label'.
    - Stores the corresponding score in 'query_topic_label_score'."""

    _accelerator = "cuda"
    _batched_op = True

    def __init__(
        self,
        hf_model: str = "dstefa/roberta-base_topic_classification_nyt_news",  # noqa: E501 E131
        zh_to_en_hf_model: Optional[str] = "Helsinki-NLP/opus-mt-zh-en",
        model_params: Dict = {},
        zh_to_en_model_params: Dict = {},
        *,
        label_key: str = MetaKeys.query_topic_label,
        score_key: str = MetaKeys.query_topic_score,
        **kwargs,
    ):
        """
        Initialization method.

        :param hf_model: Huggingface model ID to predict topic label.
        :param zh_to_en_hf_model: Translation model from Chinese to English.
            If not None, translate the query from Chinese to English.
        :param model_params: model param for hf_model.
        :param zh_to_en_model_params: model param for zh_to_hf_model.
        :param label_key: The key name in the meta field to store the
            output label. It is 'query_topic_label' in default.
        :param score_key: The key name in the meta field to store the
            corresponding label score. It is 'query_topic_label_score'
            in default.
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)

        self.label_key = label_key
        self.score_key = score_key

        self.model_key = prepare_model(
            model_type="huggingface",
            pretrained_model_name_or_path=hf_model,
            return_pipe=True,
            pipe_task="text-classification",
            **model_params,
        )

        if zh_to_en_hf_model is not None:
            self.zh_to_en_model_key = prepare_model(
                model_type="huggingface",
                pretrained_model_name_or_path=zh_to_en_hf_model,
                return_pipe=True,
                pipe_task="translation",
                **zh_to_en_model_params,
            )
        else:
            self.zh_to_en_model_key = None

    def process_batched(self, samples, rank=None):
        metas = samples[Fields.meta]
        if self.label_key in metas[0] and self.score_key in metas[0]:
            return samples

        queries = samples[self.query_key]

        if self.zh_to_en_model_key is not None:
            translator, _ = get_model(self.zh_to_en_model_key, rank, self.use_cuda())
            results = translator(queries)
            queries = [item["translation_text"] for item in results]

        classifier, _ = get_model(self.model_key, rank, self.use_cuda())
        results = classifier(queries)
        labels = [r["label"] for r in results]
        scores = [r["score"] for r in results]

        for i in range(len(metas)):
            metas[i][self.label_key] = labels[i]
            metas[i][self.score_key] = scores[i]

        return samples
